This notebook contains the description and code for various types of autoencoders.
We are going to use an anime face dataset; our aim is to generate or reproduce anime faces.
# Mount Google Drive: holds the Kaggle API token and is used later to back up trained models.
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
''' Link to explain how to download Datasets from kaggle https://www.kaggle.com/general/74235'''
!pip install -q kaggle
!mkdir ~/.kaggle
!cp '/content/drive/My Drive/Kaggle/kaggle.json' ~/.kaggle/
!chmod 600 ~/.kaggle/kaggle.json
mkdir: cannot create directory ‘/root/.kaggle’: File exists
%%time
!kaggle datasets download -d splcher/animefacedataset -p dataset
!unzip -q dataset/animefacedataset.zip -d dataset/animefacedataset
!rm dataset/animefacedataset.zip
Downloading animefacedataset.zip to dataset 97% 383M/395M [00:02<00:00, 189MB/s] 100% 395M/395M [00:02<00:00, 145MB/s] replace dataset/animefacedataset/images/0_2000.jpg? [y]es, [n]o, [A]ll, [N]one, [r]ename: N CPU times: user 86.5 ms, sys: 31.2 ms, total: 118 ms Wall time: 10.8 s
import os
# Directory created by the unzip step above; contains all anime face JPEGs.
dataset_dir = "dataset/animefacedataset/images"
# Full relative paths for every image file in the dataset.
image_files = [os.path.join(dataset_dir, x) for x in os.listdir(dataset_dir)]
len(image_files)
63565
from matplotlib import pyplot as plt
import numpy as np
import math
import cv2
def plot_images(images):
    """Display a batch of images in a grid with 8 columns.

    Each element of ``images`` may be a file path (loaded with OpenCV and
    converted BGR -> RGB) or an already-decoded image array.
    """
    n_col = 8
    n_row = int(math.ceil(len(images) / n_col))
    _, axs = plt.subplots(n_row, n_col, figsize=(12, 12))
    axs = np.atleast_1d(axs).flatten()
    for img, ax in zip(images, axs):
        if isinstance(img, str):
            if not os.path.exists(img):
                # Skip missing files instead of passing the path string to
                # imshow (which would raise a TypeError).
                continue
            img = cv2.imread(img)  # OpenCV decodes as BGR
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        ax.imshow(img)
    plt.show()
from mpl_toolkits.axes_grid1 import ImageGrid
def plot_images(images, n_col=8):
    """Display a batch of images in a tight ImageGrid with ``n_col`` columns.

    Each element of ``images`` may be a file path (loaded with OpenCV,
    converted BGR -> RGB and resized to 64x64) or an already-decoded array.
    Redefines the earlier subplots-based plot_images.
    """
    n_row = int(math.ceil(len(images) / n_col))
    fig = plt.figure(figsize=(12., 12.))
    grid = ImageGrid(fig, 111,                   # single subplot filling the figure
                     nrows_ncols=(n_row, n_col), # n_row x n_col grid of axes
                     axes_pad=0.0,               # no padding between axes
                     )
    for ax, img in zip(grid, images):
        # Iterating over the grid yields the individual Axes.
        if isinstance(img, str):
            if not os.path.exists(img):
                # Skip missing files instead of handing the path string to
                # imshow (which would raise a TypeError).
                continue
            img = cv2.imread(img)
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            img = cv2.resize(img, (64, 64))  # uniform size for the grid
        ax.imshow(img)
    plt.show()
# Sanity check: display the first 16 dataset images.
plot_images(image_files[0:16])
from sklearn.model_selection import train_test_split
# 70/30 shuffled split of the image *paths*; pixels are loaded later.
images_files_train, images_files_test = train_test_split(image_files, test_size=0.3, shuffle=True)
print("Train:", len(images_files_train))
print("Test:", len(images_files_test))
Train: 44495 Test: 19070
def read_image_file(imgfile):
    """Load ``imgfile`` with OpenCV and return a 64x64x3 RGB uint8 array.

    Raises:
        ValueError: if OpenCV cannot decode the file. cv2.imread returns
            None for missing/corrupt images, which would otherwise surface
            as an opaque cvtColor error.
    """
    img = cv2.imread(imgfile)
    if img is None:
        raise ValueError("Could not read image file: %s" % imgfile)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img = cv2.resize(img, (64, 64))  # resize to the model's input resolution
    return img.astype(np.uint8)
images_train = np.array([read_image_file(x) for x in images_files_train])
images_train.shape
(44495, 64, 64, 3)
images_test = np.array([read_image_file(x) for x in images_files_test])
images_test.shape
(19070, 64, 64, 3)
images_shape = images_test[0].shape
total_pixels = np.size(images_test[0])
images_shape, total_pixels
((64, 64, 3), 12288)
This type of autoencoder contains dense layers as the encoder and decoder.
Let's try to build an autoencoder using only dense layers to reproduce the same input image.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
model_file = 'model_ae_dnn.h5'

# Fully-connected autoencoder: 12288 pixels -> 8-dim code -> 12288 pixels.
model = keras.Sequential(name="my_sequential")
# FIX: the original declared dtype=tf.int8 here, but the images are uint8
# in [0, 255]; int8 only covers [-128, 127], so bright pixels overflow.
# float32 is the dtype Keras casts to for training anyway.
model.add(keras.Input(shape=images_shape, dtype=tf.float32))
model.add(layers.Flatten())
# Encoder: progressively compress 12288 -> 8.
model.add(layers.Dense(128, activation="relu", name="encoder_layer_1"))
model.add(layers.Dense(64, activation="relu", name="encoder_layer_2"))
model.add(layers.Dense(32, activation="relu", name="encoder_layer_3"))
model.add(layers.Dense(16, activation="relu", name="encoder_layer_4"))
model.add(layers.Dense(8, name="code"))  # bottleneck (linear activation)
# Decoder: mirror of the encoder, 8 -> 12288.
model.add(layers.Dense(16, activation="relu", name="decoder_layer_1"))
model.add(layers.Dense(32, activation="relu", name="decoder_layer_2"))
model.add(layers.Dense(64, activation="relu", name="decoder_layer_3"))
model.add(layers.Dense(128, activation="relu", name="decoder_layer_4"))
# relu keeps predicted pixel intensities non-negative.
model.add(layers.Dense(total_pixels, activation="relu", name="final_layer"))
model.add(layers.Reshape(images_shape))

# Keep the best model (lowest validation loss) on disk, and stop training
# once val_loss has not improved for 5 epochs.
checkpoint = ModelCheckpoint(model_file, verbose=0, monitor='val_loss', save_best_only=True, mode='auto')
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
# Alternative losses considered:
# tf.keras.losses.MeanAbsoluteError()
# tf.keras.losses.kullback_leibler_divergence()
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.MeanSquaredError(),
              metrics=['mse']
              )
model.summary()
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model. Model: "my_sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten_9 (Flatten) (None, 12288) 0 _________________________________________________________________ encoder_layer_1 (Dense) (None, 128) 1572992 _________________________________________________________________ encoder_layer_2 (Dense) (None, 64) 8256 _________________________________________________________________ encoder_layer_3 (Dense) (None, 32) 2080 _________________________________________________________________ encoder_layer_4 (Dense) (None, 16) 528 _________________________________________________________________ code (Dense) (None, 8) 136 _________________________________________________________________ decoder_layer_1 (Dense) (None, 16) 144 _________________________________________________________________ decoder_layer_2 (Dense) (None, 32) 544 _________________________________________________________________ decoder_layer_3 (Dense) (None, 64) 2112 _________________________________________________________________ decoder_layer_4 (Dense) (None, 128) 8320 _________________________________________________________________ final_layer (Dense) (None, 12288) 1585152 _________________________________________________________________ reshape_9 (Reshape) (None, 64, 64, 3) 0 ================================================================= Total params: 3,180,264 Trainable params: 3,180,264 Non-trainable params: 0 _________________________________________________________________
%%time
# Train the autoencoder to reconstruct its own input (target == input).
# EarlyStopping (patience=5) ends training long before 500 epochs.
model.fit(images_train, images_train, batch_size=16, epochs=500, validation_split=0.2, callbacks=[checkpoint, early_stopping], shuffle=True)
model.save(model_file) # Save Best model to disk
Epoch 1/500 2225/2225 [==============================] - 16s 7ms/step - loss: 7745.4814 - mse: 7745.4814 - val_loss: 5789.5469 - val_mse: 5789.5469 Epoch 2/500 2225/2225 [==============================] - 15s 7ms/step - loss: 5094.6670 - mse: 5094.6670 - val_loss: 4754.9946 - val_mse: 4754.9946 Epoch 3/500 2225/2225 [==============================] - 15s 7ms/step - loss: 4244.2397 - mse: 4244.2397 - val_loss: 4173.7920 - val_mse: 4173.7920 Epoch 4/500 2225/2225 [==============================] - 15s 7ms/step - loss: 3777.8196 - mse: 3777.8196 - val_loss: 3882.5774 - val_mse: 3882.5774 Epoch 5/500 2225/2225 [==============================] - 15s 7ms/step - loss: 3416.8728 - mse: 3416.8728 - val_loss: 3642.4468 - val_mse: 3642.4468 Epoch 6/500 2225/2225 [==============================] - 15s 7ms/step - loss: 3163.1895 - mse: 3163.1895 - val_loss: 3392.8254 - val_mse: 3392.8254 Epoch 7/500 2225/2225 [==============================] - 15s 7ms/step - loss: 3011.4729 - mse: 3011.4729 - val_loss: 3302.9238 - val_mse: 3302.9238 Epoch 8/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2893.6626 - mse: 2893.6626 - val_loss: 3230.7539 - val_mse: 3230.7539 Epoch 9/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2801.2739 - mse: 2801.2739 - val_loss: 3111.6794 - val_mse: 3111.6794 Epoch 10/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2709.9768 - mse: 2709.9768 - val_loss: 3024.9519 - val_mse: 3024.9519 Epoch 11/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2628.3242 - mse: 2628.3242 - val_loss: 2966.7283 - val_mse: 2966.7283 Epoch 12/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2557.0449 - mse: 2557.0449 - val_loss: 2912.2588 - val_mse: 2912.2588 Epoch 13/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2502.5159 - mse: 2502.5159 - val_loss: 2850.7229 - val_mse: 2850.7229 Epoch 14/500 2225/2225 [==============================] - 15s 7ms/step 
- loss: 2455.8535 - mse: 2455.8535 - val_loss: 2820.7717 - val_mse: 2820.7717 Epoch 15/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2415.6208 - mse: 2415.6208 - val_loss: 2782.8862 - val_mse: 2782.8862 Epoch 16/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2378.2502 - mse: 2378.2502 - val_loss: 2762.5403 - val_mse: 2762.5403 Epoch 17/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2352.1372 - mse: 2352.1372 - val_loss: 2739.2517 - val_mse: 2739.2517 Epoch 18/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2318.7959 - mse: 2318.7959 - val_loss: 2712.8945 - val_mse: 2712.8945 Epoch 19/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2296.4275 - mse: 2296.4275 - val_loss: 2706.2317 - val_mse: 2706.2317 Epoch 20/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2276.7935 - mse: 2276.7935 - val_loss: 2682.8230 - val_mse: 2682.8230 Epoch 21/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2259.1987 - mse: 2259.1987 - val_loss: 2671.7095 - val_mse: 2671.7095 Epoch 22/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2243.9900 - mse: 2243.9900 - val_loss: 2656.2014 - val_mse: 2656.2014 Epoch 23/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2225.9990 - mse: 2225.9990 - val_loss: 2628.5667 - val_mse: 2628.5667 Epoch 24/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2201.2419 - mse: 2201.2419 - val_loss: 2633.5908 - val_mse: 2633.5908 Epoch 25/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2187.1528 - mse: 2187.1528 - val_loss: 2603.9958 - val_mse: 2603.9958 Epoch 26/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2167.0247 - mse: 2167.0247 - val_loss: 2579.0828 - val_mse: 2579.0828 Epoch 27/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2147.0837 - mse: 2147.0837 - val_loss: 2565.5955 - 
val_mse: 2565.5955 Epoch 28/500 2225/2225 [==============================] - 14s 7ms/step - loss: 2131.2476 - mse: 2131.2476 - val_loss: 2553.0916 - val_mse: 2553.0916 Epoch 29/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2125.1418 - mse: 2125.1418 - val_loss: 2546.6794 - val_mse: 2546.6794 Epoch 30/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2113.3552 - mse: 2113.3552 - val_loss: 2535.7832 - val_mse: 2535.7832 Epoch 31/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2106.2405 - mse: 2106.2405 - val_loss: 2535.9009 - val_mse: 2535.9009 Epoch 32/500 2225/2225 [==============================] - 14s 7ms/step - loss: 2100.7708 - mse: 2100.7708 - val_loss: 2521.9089 - val_mse: 2521.9089 Epoch 33/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2093.4861 - mse: 2093.4861 - val_loss: 2529.3059 - val_mse: 2529.3059 Epoch 34/500 2225/2225 [==============================] - 14s 7ms/step - loss: 2086.1904 - mse: 2086.1904 - val_loss: 2506.7351 - val_mse: 2506.7351 Epoch 35/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2084.8545 - mse: 2084.8545 - val_loss: 2516.1262 - val_mse: 2516.1262 Epoch 36/500 2225/2225 [==============================] - 14s 7ms/step - loss: 2077.9331 - mse: 2077.9331 - val_loss: 2503.3787 - val_mse: 2503.3787 Epoch 37/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2072.5806 - mse: 2072.5806 - val_loss: 2494.9321 - val_mse: 2494.9321 Epoch 38/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2067.7234 - mse: 2067.7234 - val_loss: 2494.3889 - val_mse: 2494.3889 Epoch 39/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2062.2871 - mse: 2062.2871 - val_loss: 2483.0850 - val_mse: 2483.0850 Epoch 40/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2057.3281 - mse: 2057.3281 - val_loss: 2478.1245 - val_mse: 2478.1245 Epoch 41/500 2225/2225 
[==============================] - 14s 7ms/step - loss: 2055.6375 - mse: 2055.6375 - val_loss: 2477.5952 - val_mse: 2477.5952 Epoch 42/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2047.7416 - mse: 2047.7416 - val_loss: 2474.7520 - val_mse: 2474.7520 Epoch 43/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2044.8254 - mse: 2044.8254 - val_loss: 2487.7310 - val_mse: 2487.7310 Epoch 44/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2041.1714 - mse: 2041.1714 - val_loss: 2473.2947 - val_mse: 2473.2947 Epoch 45/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2038.4335 - mse: 2038.4335 - val_loss: 2460.0046 - val_mse: 2460.0046 Epoch 46/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2036.6165 - mse: 2036.6165 - val_loss: 2459.0994 - val_mse: 2459.0994 Epoch 47/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2034.9250 - mse: 2034.9250 - val_loss: 2459.0425 - val_mse: 2459.0425 Epoch 48/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2031.9937 - mse: 2031.9937 - val_loss: 2445.1523 - val_mse: 2445.1523 Epoch 49/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2029.2157 - mse: 2029.2157 - val_loss: 2448.8923 - val_mse: 2448.8923 Epoch 50/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2027.5762 - mse: 2027.5762 - val_loss: 2450.2930 - val_mse: 2450.2930 Epoch 51/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2024.1421 - mse: 2024.1421 - val_loss: 2455.2170 - val_mse: 2455.2170 Epoch 52/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2020.7971 - mse: 2020.7971 - val_loss: 2451.8428 - val_mse: 2451.8428 Epoch 53/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2016.3461 - mse: 2016.3461 - val_loss: 2430.4321 - val_mse: 2430.4321 Epoch 54/500 2225/2225 [==============================] - 15s 7ms/step - loss: 
2011.9696 - mse: 2011.9696 - val_loss: 2423.0098 - val_mse: 2423.0098 Epoch 55/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2008.6443 - mse: 2008.6443 - val_loss: 2413.8279 - val_mse: 2413.8279 Epoch 56/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2006.3228 - mse: 2006.3228 - val_loss: 2420.9802 - val_mse: 2420.9802 Epoch 57/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2005.4141 - mse: 2005.4141 - val_loss: 2434.0698 - val_mse: 2434.0698 Epoch 58/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2002.9851 - mse: 2002.9851 - val_loss: 2432.4451 - val_mse: 2432.4451 Epoch 59/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2001.9885 - mse: 2001.9885 - val_loss: 2427.0110 - val_mse: 2427.0110 Epoch 60/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1999.6281 - mse: 1999.6281 - val_loss: 2413.1631 - val_mse: 2413.1631 Epoch 61/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1999.6019 - mse: 1999.6019 - val_loss: 2405.9392 - val_mse: 2405.9392 Epoch 62/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1997.6655 - mse: 1997.6655 - val_loss: 2401.0386 - val_mse: 2401.0386 Epoch 63/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1996.4545 - mse: 1996.4545 - val_loss: 2411.3904 - val_mse: 2411.3904 Epoch 64/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1995.7986 - mse: 1995.7986 - val_loss: 2403.6282 - val_mse: 2403.6282 Epoch 65/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1994.1902 - mse: 1994.1902 - val_loss: 2420.5930 - val_mse: 2420.5930 Epoch 66/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1993.4651 - mse: 1993.4651 - val_loss: 2401.1953 - val_mse: 2401.1953 Epoch 67/500 2225/2225 [==============================] - 15s 7ms/step - loss: 2024.5470 - mse: 2024.5470 - val_loss: 2400.9795 - val_mse: 
2400.9795 Epoch 68/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1989.7943 - mse: 1989.7943 - val_loss: 2400.9592 - val_mse: 2400.9592 Epoch 69/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1984.5483 - mse: 1984.5483 - val_loss: 2392.7205 - val_mse: 2392.7205 Epoch 70/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1987.0037 - mse: 1987.0037 - val_loss: 2406.4531 - val_mse: 2406.4531 Epoch 71/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1987.8333 - mse: 1987.8333 - val_loss: 2407.9504 - val_mse: 2407.9504 Epoch 72/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1985.6984 - mse: 1985.6984 - val_loss: 2397.6138 - val_mse: 2397.6138 Epoch 73/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1985.7360 - mse: 1985.7360 - val_loss: 2416.3306 - val_mse: 2416.3306 Epoch 74/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1985.3004 - mse: 1985.3004 - val_loss: 2391.9360 - val_mse: 2391.9360 Epoch 75/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1982.4799 - mse: 1982.4799 - val_loss: 2402.4446 - val_mse: 2402.4446 Epoch 76/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1985.8153 - mse: 1985.8153 - val_loss: 2387.0242 - val_mse: 2387.0242 Epoch 77/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1979.2655 - mse: 1979.2655 - val_loss: 2397.2800 - val_mse: 2397.2800 Epoch 78/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1978.4922 - mse: 1978.4922 - val_loss: 2398.9546 - val_mse: 2398.9546 Epoch 79/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1978.4847 - mse: 1978.4847 - val_loss: 2398.9705 - val_mse: 2398.9705 Epoch 80/500 2225/2225 [==============================] - 15s 7ms/step - loss: 1978.7698 - mse: 1978.7698 - val_loss: 2404.4834 - val_mse: 2404.4834 Epoch 81/500 2225/2225 
[==============================] - 15s 7ms/step - loss: 1978.0459 - mse: 1978.0459 - val_loss: 2407.9304 - val_mse: 2407.9304 CPU times: user 19min 15s, sys: 2min 14s, total: 21min 29s Wall time: 20min 4s
# Back up the trained model to Google Drive alongside the other autoencoder checkpoints.
!mkdir -p drive/MyDrive/datasets/autoencoder/models_animefaces
!cp model_ae_dnn.h5 drive/MyDrive/datasets/autoencoder/models_animefaces
!ls -lh drive/MyDrive/datasets/autoencoder/models_animefaces
total 74M -rw------- 1 root root 6.0M Jun 5 13:22 model_ae_cnn.h5 -rw------- 1 root root 37M Jun 5 14:56 model_ae_dnn.h5 -rw------- 1 root root 31M Jun 5 10:41 model_ae_lstm.h5
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_dnn.h5'
# model.load_weights(model_file) # Load best model
model = tf.keras.models.load_model(model_file) # Load entire model
# Reconstruction MSE on the held-out test images.
model.evaluate(images_test, images_test, batch_size=8, verbose=True)
2384/2384 [==============================] - 7s 3ms/step - loss: 2376.7542 - mse: 2376.7542
[2376.754150390625, 2376.754150390625]
def display_accuracy(model, image_actual, n_col=4, text=""):
    """Show original images next to the model's reconstructions.

    Args:
        model: trained autoencoder whose predict() returns arrays shaped
            like ``image_actual``.
        image_actual: uint8 image batch of shape (N, H, W, 3).
        n_col: number of grid columns passed through to plot_images.
        text: header label printed above the grid.
    """
    print("=================================== %s ===============================" % text)
    image_generated = model.predict(image_actual, batch_size=8, verbose=False)
    # FIX: clip to the valid pixel range BEFORE the uint8 cast. The original
    # cast first and clamped afterwards, which is a no-op: casting wraps
    # out-of-range floats (e.g. 300.0 -> 44) instead of saturating them.
    image_generated = np.clip(image_generated, 0, 255).astype(np.uint8)
    # Place each original and its reconstruction side by side (width axis).
    images_side_by_side = np.concatenate([image_actual, image_generated], axis=2)
    plot_images(images_side_by_side, n_col=n_col)
images_to_display = 16
# Qualitative check: reconstructions of seen (train) vs unseen (test) faces.
display_accuracy(model, images_train[:images_to_display], text="Train Output")
display_accuracy(model, images_test[:images_to_display], text="Prediction Output")
=================================== Train Output ===============================
=================================== Prediction Output ===============================
from tensorflow import keras
# Build a standalone encoder ("code generator") by reusing the first six
# layers (Flatten + 4 encoder Dense + code) of the trained autoencoder.
# NOTE(review): dtype=tf.int8 matches the trained model's Input, but the
# images are uint8 in [0, 255] while int8 only covers [-128, 127] -- this
# looks like an overflow bug inherited from training; confirm before
# relying on the codes.
# FIX: use a dedicated local name; the original rebound "layers", shadowing
# the tensorflow.keras.layers module imported earlier in the file.
encoder_layers = [keras.Input(shape=images_shape, dtype=tf.int8)]
encoder_layers.extend(model.layers[:6])
model_code_generator = keras.Sequential(encoder_layers)
model_code_generator.build((None, images_shape[0], images_shape[1], images_shape[2]))
# Verify the reused Dense layers still share weights with the full model.
for layer in model_code_generator.layers:
    if list(filter(lambda x: x in layer.name, ['flatten', 'reshape'])):
        continue  # Flatten/Reshape layers have no weights to compare
    assert all([np.array_equal(layer.get_weights()[0], model.get_layer(layer.name).get_weights()[0]),
                np.array_equal(layer.get_weights()[1], model.get_layer(layer.name).get_weights()[1])]), "%s weights not same" % layer.name
model_code_generator.summary()
WARNING:tensorflow:Please add `keras.layers.InputLayer` instead of `keras.Input` to Sequential model. `keras.Input` is intended to be used by Functional model. Model: "sequential_6" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= flatten_9 (Flatten) (None, 12288) 0 _________________________________________________________________ encoder_layer_1 (Dense) (None, 128) 1572992 _________________________________________________________________ encoder_layer_2 (Dense) (None, 64) 8256 _________________________________________________________________ encoder_layer_3 (Dense) (None, 32) 2080 _________________________________________________________________ encoder_layer_4 (Dense) (None, 16) 528 _________________________________________________________________ code (Dense) (None, 8) 136 ================================================================= Total params: 1,583,992 Trainable params: 1,583,992 Non-trainable params: 0 _________________________________________________________________
# imgs = model_code_generator.predict(images_test[:4], batch_size=8, verbose=False).astype(np.uint8)
# plot_images(imgs, n_col=8)
# imgs = model.predict(images_test[:4], batch_size=8, verbose=False).astype(np.uint8)
# plot_images(imgs, n_col=8)
codes = model_code_generator.predict(images_test[:16], batch_size=8, verbose=False)
codes.shape
(16, 8)
print(codes[0].tolist())
print(codes[1].tolist())
print(codes[2].tolist())
[1215.6807861328125, 1890.59423828125, -1772.0360107421875, 381.75714111328125, -574.5813598632812, 311.3032531738281, -895.1006469726562, 616.4902954101562] [-1336.106689453125, -1303.230712890625, -1296.32275390625, 246.3556671142578, -646.2127075195312, -150.3787384033203, -323.1648864746094, -382.8296813964844] [703.5673217773438, 599.2479858398438, -926.3267822265625, -22.82808494567871, -1733.0594482421875, 1482.5084228515625, -673.0850219726562, -205.01351928710938]
# Aggregate statistics of the latent codes; these parameterize the random
# sampling of new codes further below.
_stat_fns = {"min": np.min, "max": np.max, "mean": np.mean, "std": np.std}
code_stats = {stat_name: stat_fn(codes) for stat_name, stat_fn in _stat_fns.items()}
code_stats
{'max': 2597.495, 'mean': -68.190186, 'min': -3308.1372, 'std': 1030.903}
But we need to remove some extra layers before that. Now we know that the code layer has 8 neurons, so we are going to generate 8 random numbers and pass them to our decoder layers.
import tensorflow as tf
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_dnn.h5'
model = tf.keras.models.load_model(model_file) # Load entire model
# model.summary()
from tensorflow import keras
# Standalone decoder ("generator"): layers 6.. of the autoencoder, i.e. the
# four decoder Dense layers, final_layer and the output Reshape.
model_generator = keras.Sequential(model.layers[6:])
# The decoder consumes the 8-dimensional code directly.
model_generator.build((None, 8))
model_generator.summary()
Model: "sequential_7" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= decoder_layer_1 (Dense) (None, 16) 144 _________________________________________________________________ decoder_layer_2 (Dense) (None, 32) 544 _________________________________________________________________ decoder_layer_3 (Dense) (None, 64) 2112 _________________________________________________________________ decoder_layer_4 (Dense) (None, 128) 8320 _________________________________________________________________ final_layer (Dense) (None, 12288) 1585152 _________________________________________________________________ reshape_9 (Reshape) (None, 64, 64, 3) 0 ================================================================= Total params: 1,596,272 Trainable params: 1,596,272 Non-trainable params: 0 _________________________________________________________________
import numpy as np
# Sample 16 random 8-dim codes from a normal distribution matched to the
# statistics of the real codes, then decode them into images.
inputs = np.random.normal(code_stats['mean'], code_stats['std'], (16, 8))
# inputs = codes  # uncomment to decode real codes instead
image_generated = model_generator.predict(inputs, batch_size=8, verbose=False)
# FIX: clip to [0, 255] BEFORE casting to uint8. The original cast first
# and clamped afterwards, which is a no-op: the cast wraps out-of-range
# floats (e.g. 300.0 -> 44) instead of saturating them.
image_generated = np.clip(image_generated, 0, 255).astype(np.uint8)
plot_images(image_generated, n_col=8)
This is similar to the dense network described above, but we will use LSTM layers this time.
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint
model_file = 'model_ae_lstm.h5'
# LSTM-based autoencoder: each 64x64x3 image is treated as a sequence of
# 64 rows, each row a 192-dim (64*3) feature vector.
model = keras.Sequential(name="my_sequential")
model.add(tf.keras.layers.InputLayer(input_shape=images_shape))
# (64, 64, 3) -> (64, 192): rows become timesteps.
model.add(layers.Reshape((images_shape[0], images_shape[1] * images_shape[2])))
model.add(tf.keras.layers.LSTM(64, activation='tanh', return_sequences=True, name="encoder_layer_1"))
model.add(tf.keras.layers.LSTM(32, activation='tanh', return_sequences=True, name="encoder_layer_2"))
# Last encoder LSTM returns only its final state (no return_sequences).
model.add(tf.keras.layers.LSTM(16, activation='tanh', name="encoder_layer_3"))
model.add(layers.Dense(8, name="code"))  # 8-dim bottleneck
# Reinterpret the 8-dim code as a length-2 sequence of 4-dim vectors so the
# decoder LSTMs have a (short) time axis to run over.
model.add(layers.Reshape((2, 4)))
model.add(tf.keras.layers.LSTM(16, activation='tanh', return_sequences=True, name="decoder_layer_1"))
model.add(tf.keras.layers.LSTM(32, activation='tanh', return_sequences=True, name="decoder_layer_2"))
model.add(tf.keras.layers.LSTM(64, activation='tanh', name="decoder_layer_3"))
# relu keeps the predicted pixel intensities non-negative.
model.add(layers.Dense(total_pixels, activation="relu", name="final_layer"))
model.add(layers.Reshape(images_shape))
# Keep the best model by validation loss; stop after 5 stagnant epochs.
checkpoint = ModelCheckpoint(model_file, verbose=0, monitor='val_loss', save_best_only=True, mode='auto')
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
loss=tf.keras.losses.MeanSquaredError(),
metrics=['mse']
)
model.summary()
Model: "my_sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= reshape (Reshape) (None, 64, 192) 0 _________________________________________________________________ encoder_layer_1 (LSTM) (None, 64, 64) 65792 _________________________________________________________________ encoder_layer_2 (LSTM) (None, 64, 32) 12416 _________________________________________________________________ encoder_layer_3 (LSTM) (None, 16) 3136 _________________________________________________________________ code (Dense) (None, 8) 136 _________________________________________________________________ reshape_1 (Reshape) (None, 2, 4) 0 _________________________________________________________________ decoder_layer_1 (LSTM) (None, 2, 16) 1344 _________________________________________________________________ decoder_layer_2 (LSTM) (None, 2, 32) 6272 _________________________________________________________________ decoder_layer_3 (LSTM) (None, 64) 24832 _________________________________________________________________ final_layer (Dense) (None, 12288) 798720 _________________________________________________________________ reshape_2 (Reshape) (None, 64, 64, 3) 0 ================================================================= Total params: 912,648 Trainable params: 912,648 Non-trainable params: 0 _________________________________________________________________
%%time
# Train the LSTM autoencoder to reconstruct its own input; EarlyStopping
# (patience=5) ends training long before 500 epochs.
model.fit(images_train, images_train, batch_size=16, epochs=500, validation_split=0.2, callbacks=[checkpoint, early_stopping], shuffle=True)
model.save(model_file) # Save Best model to disk
Epoch 1/500 2225/2225 [==============================] - 47s 15ms/step - loss: 17278.6152 - mse: 17278.6152 - val_loss: 10705.0771 - val_mse: 10705.0771 Epoch 2/500 2225/2225 [==============================] - 33s 15ms/step - loss: 9260.0801 - mse: 9260.0801 - val_loss: 8613.3975 - val_mse: 8613.3975 Epoch 3/500 2225/2225 [==============================] - 33s 15ms/step - loss: 8594.0107 - mse: 8594.0107 - val_loss: 8547.8838 - val_mse: 8547.8838 Epoch 4/500 2225/2225 [==============================] - 33s 15ms/step - loss: 8581.1357 - mse: 8581.1357 - val_loss: 8547.7734 - val_mse: 8547.7734 Epoch 5/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8581.2842 - mse: 8581.2842 - val_loss: 8547.7871 - val_mse: 8547.7871 Epoch 6/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8581.1934 - mse: 8581.1934 - val_loss: 8547.9473 - val_mse: 8547.9473 Epoch 7/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8576.4355 - mse: 8576.4355 - val_loss: 8525.8848 - val_mse: 8525.8848 Epoch 8/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8556.7031 - mse: 8556.7031 - val_loss: 8523.1943 - val_mse: 8523.1943 Epoch 9/500 2225/2225 [==============================] - 32s 15ms/step - loss: 8556.5117 - mse: 8556.5117 - val_loss: 8523.4219 - val_mse: 8523.4219 Epoch 10/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8547.7285 - mse: 8547.7285 - val_loss: 8353.8105 - val_mse: 8353.8105 Epoch 11/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8227.0508 - mse: 8227.0508 - val_loss: 8235.6152 - val_mse: 8235.6152 Epoch 12/500 2225/2225 [==============================] - 32s 14ms/step - loss: 8137.2378 - mse: 8137.2378 - val_loss: 8029.8247 - val_mse: 8029.8247 Epoch 13/500 2225/2225 [==============================] - 32s 14ms/step - loss: 7558.8306 - mse: 7558.8306 - val_loss: 7133.8076 - val_mse: 7133.8076 Epoch 14/500 2225/2225 
[==============================] - 32s 14ms/step - loss: 6959.6494 - mse: 6959.6494 - val_loss: 6826.3511 - val_mse: 6826.3511 Epoch 15/500 2225/2225 [==============================] - 32s 15ms/step - loss: 6767.2334 - mse: 6767.2334 - val_loss: 6688.7368 - val_mse: 6688.7368 Epoch 16/500 2225/2225 [==============================] - 33s 15ms/step - loss: 6598.5073 - mse: 6598.5073 - val_loss: 6468.4165 - val_mse: 6468.4165 Epoch 17/500 2225/2225 [==============================] - 33s 15ms/step - loss: 6368.7534 - mse: 6368.7534 - val_loss: 6244.3647 - val_mse: 6244.3647 Epoch 18/500 2225/2225 [==============================] - 33s 15ms/step - loss: 6151.1299 - mse: 6151.1299 - val_loss: 6091.9585 - val_mse: 6091.9585 Epoch 19/500 2225/2225 [==============================] - 33s 15ms/step - loss: 6023.8628 - mse: 6023.8628 - val_loss: 5955.0894 - val_mse: 5955.0894 Epoch 20/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5918.6499 - mse: 5918.6499 - val_loss: 5845.9141 - val_mse: 5845.9141 Epoch 21/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5840.9941 - mse: 5840.9941 - val_loss: 5792.4121 - val_mse: 5792.4121 Epoch 22/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5797.5967 - mse: 5797.5967 - val_loss: 5735.8999 - val_mse: 5735.8999 Epoch 23/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5742.4775 - mse: 5742.4775 - val_loss: 5697.0938 - val_mse: 5697.0938 Epoch 24/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5682.7651 - mse: 5682.7651 - val_loss: 5656.0010 - val_mse: 5656.0010 Epoch 25/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5639.8833 - mse: 5639.8833 - val_loss: 5594.0493 - val_mse: 5594.0493 Epoch 26/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5588.4316 - mse: 5588.4316 - val_loss: 5547.0728 - val_mse: 5547.0728 Epoch 27/500 2225/2225 [==============================] - 33s 15ms/step 
- loss: 5540.7373 - mse: 5540.7373 - val_loss: 5540.0312 - val_mse: 5540.0312 Epoch 28/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5475.1758 - mse: 5475.1758 - val_loss: 5439.8794 - val_mse: 5439.8794 Epoch 29/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5434.9487 - mse: 5434.9487 - val_loss: 5411.1914 - val_mse: 5411.1914 Epoch 30/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5402.4243 - mse: 5402.4243 - val_loss: 5433.1772 - val_mse: 5433.1772 Epoch 31/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5370.1484 - mse: 5370.1484 - val_loss: 5325.5493 - val_mse: 5325.5493 Epoch 32/500 2225/2225 [==============================] - 35s 16ms/step - loss: 5324.7983 - mse: 5324.7983 - val_loss: 5298.7197 - val_mse: 5298.7197 Epoch 33/500 2225/2225 [==============================] - 34s 15ms/step - loss: 5269.9653 - mse: 5269.9653 - val_loss: 5215.4131 - val_mse: 5215.4131 Epoch 34/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5208.8257 - mse: 5208.8257 - val_loss: 5173.6958 - val_mse: 5173.6958 Epoch 35/500 2225/2225 [==============================] - 34s 15ms/step - loss: 5146.5547 - mse: 5146.5547 - val_loss: 5082.8062 - val_mse: 5082.8062 Epoch 36/500 2225/2225 [==============================] - 33s 15ms/step - loss: 5085.5649 - mse: 5085.5649 - val_loss: 5106.9917 - val_mse: 5106.9917 Epoch 37/500 2225/2225 [==============================] - 32s 15ms/step - loss: 5068.1494 - mse: 5068.1494 - val_loss: 5122.3872 - val_mse: 5122.3872 Epoch 38/500 2225/2225 [==============================] - 32s 14ms/step - loss: 5073.0225 - mse: 5073.0225 - val_loss: 5051.2993 - val_mse: 5051.2993 Epoch 39/500 2225/2225 [==============================] - 32s 14ms/step - loss: 5058.7979 - mse: 5058.7979 - val_loss: 5053.2832 - val_mse: 5053.2832 Epoch 40/500 2225/2225 [==============================] - 32s 14ms/step - loss: 5037.4282 - mse: 5037.4282 - val_loss: 
4976.8667 - val_mse: 4976.8667 Epoch 41/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4939.2544 - mse: 4939.2544 - val_loss: 4872.5215 - val_mse: 4872.5215 Epoch 42/500 2225/2225 [==============================] - 32s 15ms/step - loss: 4829.6265 - mse: 4829.6265 - val_loss: 4772.4727 - val_mse: 4772.4727 Epoch 43/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4742.4395 - mse: 4742.4395 - val_loss: 4697.9766 - val_mse: 4697.9766 Epoch 44/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4682.7646 - mse: 4682.7646 - val_loss: 4663.2148 - val_mse: 4663.2148 Epoch 45/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4615.2954 - mse: 4615.2954 - val_loss: 4563.9487 - val_mse: 4563.9487 Epoch 46/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4535.0283 - mse: 4535.0283 - val_loss: 4484.2603 - val_mse: 4484.2603 Epoch 47/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4435.7251 - mse: 4435.7251 - val_loss: 4380.4775 - val_mse: 4380.4775 Epoch 48/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4282.7090 - mse: 4282.7090 - val_loss: 4190.1650 - val_mse: 4190.1650 Epoch 49/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4159.7949 - mse: 4159.7949 - val_loss: 4125.0225 - val_mse: 4125.0225 Epoch 50/500 2225/2225 [==============================] - 32s 14ms/step - loss: 4088.7727 - mse: 4088.7727 - val_loss: 4081.5081 - val_mse: 4081.5081 Epoch 51/500 2225/2225 [==============================] - 32s 14ms/step - loss: 3796.1492 - mse: 3796.1492 - val_loss: 3567.1746 - val_mse: 3567.1746 Epoch 52/500 2225/2225 [==============================] - 32s 14ms/step - loss: 3517.8994 - mse: 3517.8994 - val_loss: 3468.4580 - val_mse: 3468.4580 Epoch 53/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3443.5837 - mse: 3443.5837 - val_loss: 3410.4177 - val_mse: 3410.4177 Epoch 54/500 
2225/2225 [==============================] - 33s 15ms/step - loss: 3397.5950 - mse: 3397.5950 - val_loss: 3372.3838 - val_mse: 3372.3838 Epoch 55/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3355.9170 - mse: 3355.9170 - val_loss: 3345.6899 - val_mse: 3345.6899 Epoch 56/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3325.5049 - mse: 3325.5049 - val_loss: 3329.6016 - val_mse: 3329.6016 Epoch 57/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3298.4011 - mse: 3298.4011 - val_loss: 3280.9995 - val_mse: 3280.9995 Epoch 58/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3226.9290 - mse: 3226.9290 - val_loss: 3194.6494 - val_mse: 3194.6494 Epoch 59/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3165.4016 - mse: 3165.4016 - val_loss: 3141.0205 - val_mse: 3141.0205 Epoch 60/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3126.9771 - mse: 3126.9771 - val_loss: 3118.6807 - val_mse: 3118.6807 Epoch 61/500 2225/2225 [==============================] - 33s 15ms/step - loss: 3091.4814 - mse: 3091.4814 - val_loss: 3055.6978 - val_mse: 3055.6978 Epoch 62/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2996.6882 - mse: 2996.6882 - val_loss: 2959.2883 - val_mse: 2959.2883 Epoch 63/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2914.3850 - mse: 2914.3850 - val_loss: 2868.0161 - val_mse: 2868.0161 Epoch 64/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2819.0962 - mse: 2819.0962 - val_loss: 2770.9736 - val_mse: 2770.9736 Epoch 65/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2718.3696 - mse: 2718.3696 - val_loss: 2690.1924 - val_mse: 2690.1924 Epoch 66/500 2225/2225 [==============================] - 32s 15ms/step - loss: 2582.4478 - mse: 2582.4478 - val_loss: 2506.6973 - val_mse: 2506.6973 Epoch 67/500 2225/2225 [==============================] - 33s 
15ms/step - loss: 2488.1816 - mse: 2488.1816 - val_loss: 2480.1951 - val_mse: 2480.1951 Epoch 68/500 2225/2225 [==============================] - 34s 15ms/step - loss: 2467.9194 - mse: 2467.9194 - val_loss: 2478.2026 - val_mse: 2478.2026 Epoch 69/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2451.6931 - mse: 2451.6931 - val_loss: 2444.1040 - val_mse: 2444.1040 Epoch 70/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2432.6746 - mse: 2432.6746 - val_loss: 2439.4841 - val_mse: 2439.4841 Epoch 71/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2416.6155 - mse: 2416.6155 - val_loss: 2406.9717 - val_mse: 2406.9717 Epoch 72/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2396.0857 - mse: 2396.0857 - val_loss: 2392.3086 - val_mse: 2392.3086 Epoch 73/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2386.6431 - mse: 2386.6431 - val_loss: 2387.3894 - val_mse: 2387.3894 Epoch 74/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2362.8672 - mse: 2362.8672 - val_loss: 2362.9543 - val_mse: 2362.9543 Epoch 75/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2347.3027 - mse: 2347.3027 - val_loss: 2347.9116 - val_mse: 2347.9116 Epoch 76/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2335.6294 - mse: 2335.6294 - val_loss: 2339.1584 - val_mse: 2339.1584 Epoch 77/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2324.3877 - mse: 2324.3877 - val_loss: 2324.0425 - val_mse: 2324.0425 Epoch 78/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2313.9614 - mse: 2313.9614 - val_loss: 2316.0813 - val_mse: 2316.0813 Epoch 79/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2308.7534 - mse: 2308.7534 - val_loss: 2303.6992 - val_mse: 2303.6992 Epoch 80/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2304.8267 - mse: 2304.8267 - 
val_loss: 2300.6892 - val_mse: 2300.6892 Epoch 81/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2295.8223 - mse: 2295.8223 - val_loss: 2318.0796 - val_mse: 2318.0796 Epoch 82/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2291.9814 - mse: 2291.9814 - val_loss: 2294.4219 - val_mse: 2294.4219 Epoch 83/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2285.5154 - mse: 2285.5154 - val_loss: 2280.4534 - val_mse: 2280.4534 Epoch 84/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2279.4600 - mse: 2279.4600 - val_loss: 2280.0750 - val_mse: 2280.0750 Epoch 85/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2270.3259 - mse: 2270.3259 - val_loss: 2264.2275 - val_mse: 2264.2275 Epoch 86/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2260.8247 - mse: 2260.8247 - val_loss: 2262.0779 - val_mse: 2262.0779 Epoch 87/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2251.9761 - mse: 2251.9761 - val_loss: 2253.4373 - val_mse: 2253.4373 Epoch 88/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2247.7524 - mse: 2247.7524 - val_loss: 2266.4182 - val_mse: 2266.4182 Epoch 89/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2242.2366 - mse: 2242.2366 - val_loss: 2243.9790 - val_mse: 2243.9790 Epoch 90/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2237.4939 - mse: 2237.4939 - val_loss: 2237.2122 - val_mse: 2237.2122 Epoch 91/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2228.7690 - mse: 2228.7690 - val_loss: 2229.9817 - val_mse: 2229.9817 Epoch 92/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2219.3467 - mse: 2219.3467 - val_loss: 2235.1709 - val_mse: 2235.1709 Epoch 93/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2211.1423 - mse: 2211.1423 - val_loss: 2209.8882 - val_mse: 2209.8882 Epoch 
94/500 2225/2225 [==============================] - 32s 14ms/step - loss: 2203.2751 - mse: 2203.2751 - val_loss: 2204.6851 - val_mse: 2204.6851 Epoch 95/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2196.6384 - mse: 2196.6384 - val_loss: 2197.9944 - val_mse: 2197.9944 Epoch 96/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2191.7378 - mse: 2191.7378 - val_loss: 2192.9355 - val_mse: 2192.9355 Epoch 97/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2185.7500 - mse: 2185.7500 - val_loss: 2190.8318 - val_mse: 2190.8318 Epoch 98/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2181.7219 - mse: 2181.7219 - val_loss: 2187.3789 - val_mse: 2187.3789 Epoch 99/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2178.2847 - mse: 2178.2847 - val_loss: 2184.3501 - val_mse: 2184.3501 Epoch 100/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2171.7197 - mse: 2171.7197 - val_loss: 2172.8926 - val_mse: 2172.8926 Epoch 101/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2168.5957 - mse: 2168.5957 - val_loss: 2169.3792 - val_mse: 2169.3792 Epoch 102/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2165.2861 - mse: 2165.2861 - val_loss: 2173.1895 - val_mse: 2173.1895 Epoch 103/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2163.6592 - mse: 2163.6592 - val_loss: 2162.5911 - val_mse: 2162.5911 Epoch 104/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2157.7532 - mse: 2157.7532 - val_loss: 2177.4827 - val_mse: 2177.4827 Epoch 105/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2155.2971 - mse: 2155.2971 - val_loss: 2160.7231 - val_mse: 2160.7231 Epoch 106/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2152.8140 - mse: 2152.8140 - val_loss: 2153.8572 - val_mse: 2153.8572 Epoch 107/500 2225/2225 
[==============================] - 33s 15ms/step - loss: 2149.6211 - mse: 2149.6211 - val_loss: 2153.2310 - val_mse: 2153.2310 Epoch 108/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2148.5159 - mse: 2148.5159 - val_loss: 2154.9385 - val_mse: 2154.9385 Epoch 109/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2144.4446 - mse: 2144.4446 - val_loss: 2147.2644 - val_mse: 2147.2644 Epoch 110/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2142.0940 - mse: 2142.0940 - val_loss: 2144.1699 - val_mse: 2144.1699 Epoch 111/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2138.1355 - mse: 2138.1355 - val_loss: 2143.0337 - val_mse: 2143.0337 Epoch 112/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2135.3999 - mse: 2135.3999 - val_loss: 2140.4258 - val_mse: 2140.4258 Epoch 113/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2137.5798 - mse: 2137.5798 - val_loss: 2151.2314 - val_mse: 2151.2314 Epoch 114/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2135.3606 - mse: 2135.3606 - val_loss: 2139.0652 - val_mse: 2139.0652 Epoch 115/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2134.5042 - mse: 2134.5042 - val_loss: 2142.6350 - val_mse: 2142.6350 Epoch 116/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2129.5422 - mse: 2129.5422 - val_loss: 2135.6750 - val_mse: 2135.6750 Epoch 117/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2128.1431 - mse: 2128.1431 - val_loss: 2142.6777 - val_mse: 2142.6777 Epoch 118/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2124.3354 - mse: 2124.3354 - val_loss: 2126.6646 - val_mse: 2126.6646 Epoch 119/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2122.3333 - mse: 2122.3333 - val_loss: 2125.1572 - val_mse: 2125.1572 Epoch 120/500 2225/2225 [==============================] - 
33s 15ms/step - loss: 2123.1204 - mse: 2123.1204 - val_loss: 2126.4717 - val_mse: 2126.4717 Epoch 121/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2117.8911 - mse: 2117.8911 - val_loss: 2135.2971 - val_mse: 2135.2971 Epoch 122/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2120.0750 - mse: 2120.0750 - val_loss: 2123.0679 - val_mse: 2123.0679 Epoch 123/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2115.5303 - mse: 2115.5303 - val_loss: 2122.5771 - val_mse: 2122.5771 Epoch 124/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2112.2561 - mse: 2112.2561 - val_loss: 2121.9148 - val_mse: 2121.9148 Epoch 125/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2111.0742 - mse: 2111.0742 - val_loss: 2117.6472 - val_mse: 2117.6472 Epoch 126/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2110.0996 - mse: 2110.0996 - val_loss: 2116.0806 - val_mse: 2116.0806 Epoch 127/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2114.4026 - mse: 2114.4026 - val_loss: 2120.9204 - val_mse: 2120.9204 Epoch 128/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2108.3433 - mse: 2108.3433 - val_loss: 2113.1472 - val_mse: 2113.1472 Epoch 129/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2106.2334 - mse: 2106.2334 - val_loss: 2127.8787 - val_mse: 2127.8787 Epoch 130/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2104.1953 - mse: 2104.1953 - val_loss: 2113.1770 - val_mse: 2113.1770 Epoch 131/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2104.2295 - mse: 2104.2295 - val_loss: 2107.2539 - val_mse: 2107.2539 Epoch 132/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2100.5710 - mse: 2100.5710 - val_loss: 2118.6841 - val_mse: 2118.6841 Epoch 133/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2101.0947 - 
mse: 2101.0947 - val_loss: 2108.8284 - val_mse: 2108.8284 Epoch 134/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2097.8513 - mse: 2097.8513 - val_loss: 2108.7864 - val_mse: 2108.7864 Epoch 135/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2100.6631 - mse: 2100.6631 - val_loss: 2099.6821 - val_mse: 2099.6821 Epoch 136/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2098.4268 - mse: 2098.4268 - val_loss: 2106.6753 - val_mse: 2106.6753 Epoch 137/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2095.9814 - mse: 2095.9814 - val_loss: 2129.8088 - val_mse: 2129.8088 Epoch 138/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2096.1062 - mse: 2096.1062 - val_loss: 2100.0891 - val_mse: 2100.0891 Epoch 139/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2091.3901 - mse: 2091.3901 - val_loss: 2098.1531 - val_mse: 2098.1531 Epoch 140/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2089.8413 - mse: 2089.8413 - val_loss: 2098.5381 - val_mse: 2098.5381 Epoch 141/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2091.3118 - mse: 2091.3118 - val_loss: 2101.8169 - val_mse: 2101.8169 Epoch 142/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2094.5698 - mse: 2094.5698 - val_loss: 2110.2651 - val_mse: 2110.2651 Epoch 143/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2090.4109 - mse: 2090.4109 - val_loss: 2098.2056 - val_mse: 2098.2056 Epoch 144/500 2225/2225 [==============================] - 33s 15ms/step - loss: 2085.6516 - mse: 2085.6516 - val_loss: 2099.2151 - val_mse: 2099.2151 CPU times: user 1h 28min 50s, sys: 5min 56s, total: 1h 34min 46s Wall time: 1h 18min 44s
# Persist the trained LSTM autoencoder to Google Drive so it survives the
# Colab session; -p silently succeeds if the directory already exists.
!mkdir -p drive/MyDrive/datasets/autoencoder/models_animefaces
!cp model_ae_lstm.h5 drive/MyDrive/datasets/autoencoder/models_animefaces
# List the saved models to confirm the copy succeeded.
!ls -lh drive/MyDrive/datasets/autoencoder/models_animefaces
total 53M -rw------- 1 root root 6.0M Jun 5 13:22 model_ae_cnn.h5 -rw------- 1 root root 37M Jun 5 14:56 model_ae_dnn.h5 -rw------- 1 root root 11M Jun 6 07:45 model_ae_lstm.h5
# model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_lstm.h5'
# model.load_weights(model_file) # Load best model
# NOTE(review): `model_file` is assumed to be defined by an earlier cell (the
# commented path above shows the intended value) — confirm before re-running.
model = tf.keras.models.load_model(model_file) # Load entire model
# Reconstruction error on the held-out images; loss and metric are both MSE.
model.evaluate(images_test, images_test, batch_size=8, verbose=True)
2384/2384 [==============================] - 17s 6ms/step - loss: 2092.5908 - mse: 2092.5908
[2092.5908203125, 2092.5908203125]
def display_accuracy(model, image_actual, n_col=4, text=""):
    """Plot actual images side by side with the model's reconstructions.

    Args:
        model: a Keras model whose ``predict`` returns images with the same
            height/width/channels as ``image_actual``.
        image_actual: batch of input images, expected in the 0-255 range.
        n_col: number of grid columns passed through to ``plot_images``.
        text: banner label printed above the grid.
    """
    print("=================================== %s ===============================" % text)
    # Clip BEFORE casting to uint8. The original cast-then-clip wrapped any
    # out-of-range float prediction via uint8 overflow (e.g. 256.0 -> 0),
    # which made the subsequent clipping a no-op.
    image_generated = np.clip(
        model.predict(image_actual, batch_size=8, verbose=False), 0, 255
    ).astype(np.uint8)
    # Concatenate along the width axis so each row shows actual | generated.
    images_side_by_side = np.concatenate([image_actual, image_generated], axis=2)
    plot_images(images_side_by_side, n_col=n_col)
# Number of samples to visualize from each split.
images_to_display = 16
# Qualitative check: reconstructions on seen (train) vs. unseen (test) images.
display_accuracy(model, images_train[:images_to_display], text="Train Output")
display_accuracy(model, images_test[:images_to_display], text="Prediction Output")
=================================== Train Output ===============================
=================================== Prediction Output ===============================
from tensorflow import keras

# Encoder half of the autoencoder: an input layer followed by the first five
# layers of the trained model (up to and including the "code" Dense layer).
encoder_stack = [tf.keras.layers.InputLayer(input_shape=images_shape)]
encoder_stack.extend(model.layers[:5])
model_code_generator = keras.Sequential(encoder_stack)
model_code_generator.build((None, images_shape[0], images_shape[1], images_shape[2]))

# Sanity check: every weighted layer must share its kernel and bias with the
# trained model. Reshape/flatten layers carry no weights, so they are skipped.
for layer in model_code_generator.layers:
    if any(token in layer.name for token in ('flatten', 'reshape')):
        continue
    source_layer = model.get_layer(layer.name)
    same_kernel = np.array_equal(layer.get_weights()[0], source_layer.get_weights()[0])
    same_bias = np.array_equal(layer.get_weights()[1], source_layer.get_weights()[1])
    assert same_kernel and same_bias, "%s weights not same" % layer.name

model_code_generator.summary()
Model: "sequential_3" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= reshape (Reshape) (None, 64, 192) 0 _________________________________________________________________ encoder_layer_1 (LSTM) (None, 64, 64) 65792 _________________________________________________________________ encoder_layer_2 (LSTM) (None, 64, 32) 12416 _________________________________________________________________ encoder_layer_3 (LSTM) (None, 16) 3136 _________________________________________________________________ code (Dense) (None, 8) 136 ================================================================= Total params: 81,480 Trainable params: 81,480 Non-trainable params: 0 _________________________________________________________________
# Encode the first 16 test images into their latent codes.
codes = model_code_generator.predict(images_test[:16], batch_size=8, verbose=False)
codes.shape
(16, 8)
# Inspect the first three latent vectors as plain Python lists.
for sample_index in range(3):
    print(codes[sample_index].tolist())
[-0.26205557584762573, -0.33697181940078735, -0.3860102891921997, 0.07216334342956543, -0.08162796497344971, 0.21333178877830505, 0.44444242119789124, -0.7647983431816101] [-0.411038339138031, 0.4667309522628784, -0.6031712293624878, -0.2982359230518341, 0.4706922173500061, -0.06073388457298279, -0.12359648942947388, 0.189530611038208] [-0.05309383571147919, -0.9304435849189758, -0.7033981680870056, -0.44504719972610474, -0.6713100671768188, 1.3494720458984375, -0.4281803071498871, -0.707094132900238]
# Summary statistics of the latent codes; used later to sample random codes
# from a matching normal distribution when generating new faces.
code_stats = {
    stat_name: stat_fn(codes)
    for stat_name, stat_fn in (
        ("min", np.min),
        ("max", np.max),
        ("mean", np.mean),
        ("std", np.std),
    )
}
code_stats
{'max': 1.349472, 'mean': -0.058297068, 'min': -1.4627534, 'std': 0.5171938}
But first we need to remove some extra layers. We now know that the code layer has 8 neurons, so we are going to generate 8 random numbers and pass them to our decoder layers.
import tensorflow as tf
# Path to the best LSTM autoencoder checkpoint saved on Google Drive.
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_lstm.h5'
model = tf.keras.models.load_model(model_file) # Load entire model
# model.summary()
from tensorflow import keras
# Decoder half: everything after the 8-unit "code" layer (layers[5:] — see
# the summary below: reshape, three LSTM decoder layers, the final Dense,
# and a reshape back to 64x64x3).
model_generator = keras.Sequential(model.layers[5:])
model_generator.build((None, 8))
model_generator.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= reshape_1 (Reshape) (None, 2, 4) 0 _________________________________________________________________ decoder_layer_1 (LSTM) (None, 2, 16) 1344 _________________________________________________________________ decoder_layer_2 (LSTM) (None, 2, 32) 6272 _________________________________________________________________ decoder_layer_3 (LSTM) (None, 64) 24832 _________________________________________________________________ final_layer (Dense) (None, 12288) 798720 _________________________________________________________________ reshape_2 (Reshape) (None, 64, 64, 3) 0 ================================================================= Total params: 831,168 Trainable params: 831,168 Non-trainable params: 0 _________________________________________________________________
import numpy as np
# Sample random latent codes from a normal distribution matched to the
# statistics of real codes (code_stats), then decode them into images.
inputs = np.random.normal(code_stats['mean'], code_stats['std'], (16, 8))
# inputs = codes
# Clip to the valid pixel range BEFORE casting to uint8: casting first wraps
# out-of-range floats via uint8 overflow (e.g. 256.0 -> 0), which made the
# original post-cast clipping a no-op.
image_generated = np.clip(
    model_generator.predict(inputs, batch_size=8, verbose=False), 0, 255
).astype(np.uint8)
plot_images(image_generated, n_col=8)
This will be similar to the Dense network described above, but we will use CNN layers this time.
# from numba import cuda
# device = cuda.get_current_device()
# device.reset()
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.callbacks import ModelCheckpoint

model_file = 'model_ae_cnn.h5'
filter_size = (3, 3)

# Convolutional autoencoder: four stride-2 Conv2D layers compress the input
# image to an 8-unit latent code; five stride-2 Conv2DTranspose layers
# expand it back to the original 64x64x3 resolution.
model = keras.Sequential(name="autoencoder_cnn")
model.add(tf.keras.layers.InputLayer(input_shape=images_shape))

# Encoder: each stride-2 conv halves the spatial size (64 -> 32 -> 16 -> 8 -> 4).
for depth, n_filters in enumerate((128, 64, 32, 16), start=1):
    model.add(tf.keras.layers.Conv2D(
        n_filters, filter_size, activation='relu', padding='same',
        strides=2, name='encoder_layer_%d' % depth))

# Bottleneck: flatten the 4x4x16 feature map and project to an 8-unit code.
model.add(layers.Flatten())
model.add(layers.Dense(8, name="code"))

# Decoder: reshape the code into a tiny feature map, then upsample with
# stride-2 transposed convolutions (2 -> 4 -> 8 -> 16 -> 32 -> 64).
model.add(layers.Reshape((2, 2, 2)))
for depth, n_filters in enumerate((16, 32, 64, 128), start=1):
    model.add(tf.keras.layers.Conv2DTranspose(
        n_filters, filter_size, activation='relu', padding='same',
        strides=2, name='decoder_layer_%d' % depth))
# Final layer maps back to 3 RGB channels. It keeps the original name
# 'decoder_layer_6' (index 5 belonged to a since-removed 256-filter layer)
# so saved checkpoints still match.
model.add(tf.keras.layers.Conv2DTranspose(
    3, filter_size, activation='relu', padding='same',
    strides=2, name='decoder_layer_6'))

# Keep only the best-validation-loss weights on disk, and stop training once
# val_loss has not improved for 5 consecutive epochs.
checkpoint = ModelCheckpoint(model_file, verbose=0, monitor='val_loss', save_best_only=True, mode='auto')
early_stopping = tf.keras.callbacks.EarlyStopping(monitor='val_loss', patience=5, restore_best_weights=True)

model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss=tf.keras.losses.MeanSquaredError(),
              metrics=['mse'])
model.summary()
Model: "autoencoder_cnn" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= encoder_layer_1 (Conv2D) (None, 32, 32, 128) 3584 _________________________________________________________________ encoder_layer_2 (Conv2D) (None, 16, 16, 64) 73792 _________________________________________________________________ encoder_layer_3 (Conv2D) (None, 8, 8, 32) 18464 _________________________________________________________________ encoder_layer_4 (Conv2D) (None, 4, 4, 16) 4624 _________________________________________________________________ flatten (Flatten) (None, 256) 0 _________________________________________________________________ code (Dense) (None, 8) 2056 _________________________________________________________________ reshape (Reshape) (None, 2, 2, 2) 0 _________________________________________________________________ decoder_layer_1 (Conv2DTrans (None, 4, 4, 16) 304 _________________________________________________________________ decoder_layer_2 (Conv2DTrans (None, 8, 8, 32) 4640 _________________________________________________________________ decoder_layer_3 (Conv2DTrans (None, 16, 16, 64) 18496 _________________________________________________________________ decoder_layer_4 (Conv2DTrans (None, 32, 32, 128) 73856 _________________________________________________________________ decoder_layer_6 (Conv2DTrans (None, 64, 64, 3) 3459 ================================================================= Total params: 203,275 Trainable params: 203,275 Non-trainable params: 0 _________________________________________________________________
%%time
# Train the CNN autoencoder to reconstruct its input (input == target).
# 20% of the training images are held out for validation; EarlyStopping
# restores the best weights and ModelCheckpoint keeps the best file on disk.
model.fit(images_train, images_train, batch_size=32, epochs=500, validation_split=0.2, callbacks=[checkpoint, early_stopping], shuffle=True)
model.save(model_file) # Save Best model to disk
Epoch 1/500 1113/1113 [==============================] - 42s 14ms/step - loss: 3058.9998 - mse: 3058.9998 - val_loss: 2511.7271 - val_mse: 2511.7271 Epoch 2/500 1113/1113 [==============================] - 15s 14ms/step - loss: 2468.7471 - mse: 2468.7471 - val_loss: 2443.8110 - val_mse: 2443.8110 Epoch 3/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2411.2112 - mse: 2411.2112 - val_loss: 2394.8528 - val_mse: 2394.8528 Epoch 4/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2384.5640 - mse: 2384.5640 - val_loss: 2383.0754 - val_mse: 2383.0754 Epoch 5/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2369.5110 - mse: 2369.5110 - val_loss: 2363.0452 - val_mse: 2363.0452 Epoch 6/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2359.4460 - mse: 2359.4460 - val_loss: 2373.9758 - val_mse: 2373.9758 Epoch 7/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2339.5972 - mse: 2339.5972 - val_loss: 2325.3003 - val_mse: 2325.3003 Epoch 8/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2319.3804 - mse: 2319.3804 - val_loss: 2307.9702 - val_mse: 2307.9702 Epoch 9/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2300.3269 - mse: 2300.3269 - val_loss: 2301.1951 - val_mse: 2301.1951 Epoch 10/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2286.0310 - mse: 2286.0310 - val_loss: 2281.6335 - val_mse: 2281.6335 Epoch 11/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2275.1306 - mse: 2275.1306 - val_loss: 2270.2104 - val_mse: 2270.2104 Epoch 12/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2264.5574 - mse: 2264.5574 - val_loss: 2274.3042 - val_mse: 2274.3042 Epoch 13/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2257.6902 - mse: 2257.6902 - val_loss: 2270.7708 - val_mse: 2270.7708 Epoch 14/500 1113/1113 [==============================] - 
16s 15ms/step - loss: 2253.0427 - mse: 2253.0427 - val_loss: 2251.2915 - val_mse: 2251.2915 Epoch 15/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2247.2974 - mse: 2247.2974 - val_loss: 2247.7493 - val_mse: 2247.7493 Epoch 16/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2241.0981 - mse: 2241.0981 - val_loss: 2249.4192 - val_mse: 2249.4192 Epoch 17/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2237.9487 - mse: 2237.9487 - val_loss: 2238.3975 - val_mse: 2238.3975 Epoch 18/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2231.9199 - mse: 2231.9199 - val_loss: 2270.9529 - val_mse: 2270.9529 Epoch 19/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2227.9573 - mse: 2227.9573 - val_loss: 2248.0601 - val_mse: 2248.0601 Epoch 20/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2223.6274 - mse: 2223.6274 - val_loss: 2235.0117 - val_mse: 2235.0117 Epoch 21/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2219.8855 - mse: 2219.8855 - val_loss: 2221.4517 - val_mse: 2221.4517 Epoch 22/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2214.5239 - mse: 2214.5239 - val_loss: 2222.6326 - val_mse: 2222.6326 Epoch 23/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2209.3633 - mse: 2209.3633 - val_loss: 2233.6631 - val_mse: 2233.6631 Epoch 24/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2203.8315 - mse: 2203.8315 - val_loss: 2217.5784 - val_mse: 2217.5784 Epoch 25/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2197.1934 - mse: 2197.1934 - val_loss: 2205.1272 - val_mse: 2205.1272 Epoch 26/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2191.9189 - mse: 2191.9189 - val_loss: 2202.4661 - val_mse: 2202.4661 Epoch 27/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2185.0027 - mse: 2185.0027 
- val_loss: 2192.4851 - val_mse: 2192.4851 Epoch 28/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2176.2422 - mse: 2176.2422 - val_loss: 2181.2500 - val_mse: 2181.2500 Epoch 29/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2169.8945 - mse: 2169.8945 - val_loss: 2177.0142 - val_mse: 2177.0142 Epoch 30/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2163.1399 - mse: 2163.1399 - val_loss: 2170.9331 - val_mse: 2170.9331 Epoch 31/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2158.5430 - mse: 2158.5430 - val_loss: 2165.5391 - val_mse: 2165.5391 Epoch 32/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2153.6353 - mse: 2153.6353 - val_loss: 2169.0059 - val_mse: 2169.0059 Epoch 33/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2150.2375 - mse: 2150.2375 - val_loss: 2184.1431 - val_mse: 2184.1431 Epoch 34/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2146.2983 - mse: 2146.2983 - val_loss: 2160.8652 - val_mse: 2160.8652 Epoch 35/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2141.6951 - mse: 2141.6951 - val_loss: 2162.7627 - val_mse: 2162.7627 Epoch 36/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2140.3684 - mse: 2140.3684 - val_loss: 2148.9524 - val_mse: 2148.9524 Epoch 37/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2138.5645 - mse: 2138.5645 - val_loss: 2149.0491 - val_mse: 2149.0491 Epoch 38/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2136.0417 - mse: 2136.0417 - val_loss: 2139.4277 - val_mse: 2139.4277 Epoch 39/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2134.2087 - mse: 2134.2087 - val_loss: 2147.2849 - val_mse: 2147.2849 Epoch 40/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2132.8669 - mse: 2132.8669 - val_loss: 2142.0872 - val_mse: 2142.0872 Epoch 
41/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2128.8611 - mse: 2128.8611 - val_loss: 2143.2449 - val_mse: 2143.2449 Epoch 42/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2128.0315 - mse: 2128.0315 - val_loss: 2138.1052 - val_mse: 2138.1052 Epoch 43/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2125.8042 - mse: 2125.8042 - val_loss: 2137.4983 - val_mse: 2137.4983 Epoch 44/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2126.8896 - mse: 2126.8896 - val_loss: 2135.8518 - val_mse: 2135.8518 Epoch 45/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2125.1531 - mse: 2125.1531 - val_loss: 2145.0603 - val_mse: 2145.0603 Epoch 46/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2123.4487 - mse: 2123.4487 - val_loss: 2131.9419 - val_mse: 2131.9419 Epoch 47/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2121.2905 - mse: 2121.2905 - val_loss: 2146.5078 - val_mse: 2146.5078 Epoch 48/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2119.4956 - mse: 2119.4956 - val_loss: 2132.5427 - val_mse: 2132.5427 Epoch 49/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2118.7007 - mse: 2118.7007 - val_loss: 2127.1130 - val_mse: 2127.1130 Epoch 50/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2117.9265 - mse: 2117.9265 - val_loss: 2132.1787 - val_mse: 2132.1787 Epoch 51/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2117.5581 - mse: 2117.5581 - val_loss: 2129.9849 - val_mse: 2129.9849 Epoch 52/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2117.7019 - mse: 2117.7019 - val_loss: 2128.9792 - val_mse: 2128.9792 Epoch 53/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2115.8765 - mse: 2115.8765 - val_loss: 2128.2673 - val_mse: 2128.2673 Epoch 54/500 1113/1113 [==============================] 
- 16s 14ms/step - loss: 2113.9507 - mse: 2113.9507 - val_loss: 2121.6313 - val_mse: 2121.6313 Epoch 55/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2113.0039 - mse: 2113.0039 - val_loss: 2128.6384 - val_mse: 2128.6384 Epoch 56/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2110.9111 - mse: 2110.9111 - val_loss: 2128.9453 - val_mse: 2128.9453 Epoch 57/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2111.1548 - mse: 2111.1548 - val_loss: 2123.9087 - val_mse: 2123.9087 Epoch 58/500 1113/1113 [==============================] - 16s 14ms/step - loss: 2110.3159 - mse: 2110.3159 - val_loss: 2124.1143 - val_mse: 2124.1143 Epoch 59/500 1113/1113 [==============================] - 16s 15ms/step - loss: 2112.7029 - mse: 2112.7029 - val_loss: 2127.9968 - val_mse: 2127.9968 CPU times: user 14min 33s, sys: 53.8 s, total: 15min 26s Wall time: 16min 17s
# Persist the trained CNN autoencoder checkpoint to Google Drive so it
# survives the Colab session. -p: no error if the directory already exists.
!mkdir -p drive/MyDrive/datasets/autoencoder/models_animefaces
!cp model_ae_cnn.h5 drive/MyDrive/datasets/autoencoder/models_animefaces
# List the saved models to confirm the copy succeeded.
!ls -lh drive/MyDrive/datasets/autoencoder/models_animefaces
total 50M -rw------- 1 root root 2.5M Jun 6 08:16 model_ae_cnn.h5 -rw------- 1 root root 37M Jun 5 14:56 model_ae_dnn.h5 -rw------- 1 root root 11M Jun 6 07:45 model_ae_lstm.h5
# Reload the saved CNN autoencoder from Drive and sanity-check its
# reconstruction error on the held-out set (images_test is defined earlier
# in the notebook).
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_cnn.h5'
# model.load_weights(model_file) # Load best model
model = tf.keras.models.load_model(model_file) # Load entire model
# For an autoencoder, input and target are the same images.
model.evaluate(images_test, images_test, batch_size=8, verbose=True)
2384/2384 [==============================] - 10s 4ms/step - loss: 2117.5776 - mse: 2117.5776
[2117.57763671875, 2117.57763671875]
def display_accuracy(model, image_actual, n_col=4, text=""):
    """Plot original images side-by-side with the model's reconstructions.

    Parameters
    ----------
    model : Keras model whose predict() returns images shaped like its input
    image_actual : uint8 image array, shape (N, H, W, C)
    n_col : number of grid columns forwarded to plot_images
    text : heading printed above the grid
    """
    print("=================================== %s ===============================" % text)
    image_generated = model.predict(image_actual, batch_size=8, verbose=False)
    # BUGFIX: clip to [0, 255] BEFORE casting. astype(np.uint8) wraps
    # out-of-range floats (e.g. 260.0 -> 4, -3.0 -> 253), so the original
    # code's clamping after the cast was dead code.
    image_generated = np.clip(image_generated, 0, 255).astype(np.uint8)
    # Concatenate along the width axis so each cell shows actual | generated.
    images_side_by_side = np.concatenate([image_actual, image_generated], axis=2)
    plot_images(images_side_by_side, n_col=n_col)
# Visual sanity check: originals vs reconstructions for the first 16
# training and test images.
images_to_display = 16
display_accuracy(model, images_train[:images_to_display], text="Train Output")
display_accuracy(model, images_test[:images_to_display], text="Prediction Output")
=================================== Train Output ===============================
=================================== Prediction Output ===============================
from tensorflow import keras
# Build an encoder-only model ("code generator") by reusing the first six
# layers of the trained autoencoder: the conv stack, flatten, and the dense
# 8-neuron code layer.
# Layers to be used
layers = [tf.keras.layers.InputLayer(input_shape=images_shape)]
layers.extend(model.layers[:6])
model_code_generator = keras.Sequential(layers)
model_code_generator.build((None, images_shape[0], images_shape[1], images_shape[2]))
# Verify the reused layers still share their weights with the full model.
for layer in model_code_generator.layers:
    # flatten/reshape layers carry no weights, so skip the comparison.
    if list(filter(lambda x: x in layer.name, ['flatten', 'reshape'])):
        continue
    # Compare kernel ([0]) and bias ([1]) arrays against the original model.
    assert all([np.array_equal(layer.get_weights()[0], model.get_layer(layer.name).get_weights()[0]),
                np.array_equal(layer.get_weights()[1], model.get_layer(layer.name).get_weights()[1])]), "%s weights not same" % layer.name
model_code_generator.summary()
Model: "sequential_1" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= encoder_layer_1 (Conv2D) (None, 32, 32, 128) 3584 _________________________________________________________________ encoder_layer_2 (Conv2D) (None, 16, 16, 64) 73792 _________________________________________________________________ encoder_layer_3 (Conv2D) (None, 8, 8, 32) 18464 _________________________________________________________________ encoder_layer_4 (Conv2D) (None, 4, 4, 16) 4624 _________________________________________________________________ flatten_5 (Flatten) (None, 256) 0 _________________________________________________________________ code (Dense) (None, 8) 2056 ================================================================= Total params: 102,520 Trainable params: 102,520 Non-trainable params: 0 _________________________________________________________________
codes = model_code_generator.predict(images_test[:16], batch_size=8, verbose=False)
codes.shape
(16, 8)
print(codes[0].tolist())
print(codes[1].tolist())
print(codes[2].tolist())
[-90.84349060058594, 78.14624786376953, 61.19371795654297, 47.44902038574219, 135.2177734375, 48.00750732421875, -166.69300842285156, -8.807793617248535] [-82.30125427246094, 82.98831939697266, 61.37909698486328, 43.42222213745117, 97.01995849609375, 58.46400833129883, -44.247249603271484, 2.3804984092712402] [-55.21834945678711, 40.70209503173828, 36.956214904785156, 42.93532180786133, 128.26377868652344, 51.0125846862793, -89.1243667602539, -20.50321388244629]
# Summary statistics of the latent codes; used below to sample plausible
# random codes for the decoder.
code_stats = {
    stat_name: reducer(codes)
    for stat_name, reducer in (
        ("min", np.min),
        ("max", np.max),
        ("mean", np.mean),
        ("std", np.std),
    )
}
code_stats
{'max': 136.47034, 'mean': 27.564598, 'min': -187.67128, 'std': 69.36067}
But first we need to remove some extra layers. We now know that the code layer has 8 neurons, so we will generate 8 random numbers and pass them to our decoder layers.
import tensorflow as tf
# Reload the full autoencoder, then keep only layers 6 onward (reshape plus
# the transposed-conv stack) to obtain a standalone decoder/generator that
# maps an 8-dimensional code to a 64x64x3 image.
model_file = '/content/drive/MyDrive/datasets/autoencoder/models_animefaces/model_ae_cnn.h5'
model = tf.keras.models.load_model(model_file) # Load entire model
# model.summary()
from tensorflow import keras
model_generator = keras.Sequential(model.layers[6:])
model_generator.build((None, 8))  # decoder input is the 8-neuron code
model_generator.summary()
Model: "sequential_4" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= reshape_8 (Reshape) (None, 2, 2, 2) 0 _________________________________________________________________ decoder_layer_1 (Conv2DTrans (None, 4, 4, 16) 304 _________________________________________________________________ decoder_layer_2 (Conv2DTrans (None, 8, 8, 32) 4640 _________________________________________________________________ decoder_layer_3 (Conv2DTrans (None, 16, 16, 64) 18496 _________________________________________________________________ decoder_layer_4 (Conv2DTrans (None, 32, 32, 128) 73856 _________________________________________________________________ decoder_layer_6 (Conv2DTrans (None, 64, 64, 3) 3459 ================================================================= Total params: 100,755 Trainable params: 100,755 Non-trainable params: 0 _________________________________________________________________
import numpy as np

# Sample random latent codes from a normal distribution matched to the
# mean/std of the real codes, then decode them into images.
inputs = np.random.normal(code_stats['mean'], code_stats['std'], (16, 8))
# inputs = codes
image_generated = model_generator.predict(inputs, batch_size=8, verbose=False)
# BUGFIX: clip to [0, 255] BEFORE casting. astype(np.uint8) wraps
# out-of-range floats, so the original clamping after the cast was dead code
# and wrapped pixels (e.g. -3.0 -> 253) were displayed unclamped.
image_generated = np.clip(image_generated, 0, 255).astype(np.uint8)
plot_images(image_generated, n_col=8)